#include "RPCManager.h"

#include <memory>

#include "System.h"
#include "async/AsyncTask.h"

// Number of slots in the request-timeout tick wheel (tick_objs_).
// TickObjs() drains slot tick_time_ % DEF_RPC_TICK_SIZE once per tick. Entries
// whose expiry is still in the future are left in place, so timeouts longer
// than this many ticks simply survive extra laps of the wheel.
// NOTE(review): a tick presumably corresponds to one GET_UNIX_TIME unit
// (seconds) — confirm against System.h.
#define DEF_RPC_TICK_SIZE (256)

// Deferred delivery of one RPC reply (or timeout/interrupt) to a request's
// callback. The task owns the reply packet (may be null) and frees it when
// destroyed. The shared `hinder` word lets an error task suppress any
// earlier-queued, not-yet-finished tasks of the same request: once it is
// non-zero, only the task whose address it holds may invoke the callback.
class RPCManager::RPCAsyncTask : public AsyncTask {
public:
    RPCAsyncTask(const std::shared_ptr<std::atomic<uintptr_t>> &hinder,
            const std::function<void(INetStream&, int32, bool)> &cb,
            INetPacket *pck, int32 err, bool eof)
        : hinder_(hinder), cb_(cb), pck_(pck), err_(err), eof_(eof)
    {}
    // The packet is released by pck_ (unique_ptr); nothing else to clean up.
    virtual ~RPCAsyncTask() = default;

    // Owns a packet; copying would have produced a double delete.
    RPCAsyncTask(const RPCAsyncTask&) = delete;
    RPCAsyncTask &operator=(const RPCAsyncTask&) = delete;

    // Invokes the callback unless another (error) task of the same request
    // has claimed the hinder slot. A missing packet is replaced by the
    // shared empty null_packet_.
    virtual void Finish(AsyncTaskOwner *owner) {
        const uintptr_t hinder = hinder_->load();
        if (hinder == 0 || hinder == uintptr_t(this)) {
            cb_(pck_ ? *pck_ : null_packet_, err_, eof_);
        }
    }
    // All work happens in Finish(); nothing runs on the async side.
    virtual void ExecuteInAsync() {
    }
private:
    const std::shared_ptr<std::atomic<uintptr_t>> hinder_;
    const std::function<void(INetStream&, int32, bool)> cb_;
    const std::unique_ptr<INetPacket> pck_;
    const int32 err_;
    const bool eof_;
};

// Sentinel owner handle; the constructor links it to a static AsyncTaskOwner.
std::weak_ptr<AsyncTaskOwner> RPCManager::null_owner_{};
// Empty stand-in packet used wherever no real packet exists
// (callback delivery without a reply packet, requests without a transport).
ConstNetPacket RPCManager::null_packet_{};

RPCManager::RPCManager()
: tick_objs_(DEF_RPC_TICK_SIZE)
, tick_time_(GET_UNIX_TIME)
, request_sn_(1)
{
    // sNullOwner is a process-wide sentinel. DoReply() compares a request's
    // owner against null_owner_ and runs the callback synchronously on a match.
    // Link null_owner_ exactly once, through a thread-safe static initializer.
    // Re-assigning the shared static weak_ptr on every construction would race
    // with DoReply() calling null_owner_.lock() on other threads.
    static AsyncTaskOwner sNullOwner;
    static const bool sNullOwnerLinked = [] {
        null_owner_ = sNullOwner.linked_from_this();
        return true;
    }();
    (void)sNullOwnerLinked;
}

RPCManager::~RPCManager()
{
    // Requests still in the table never reached their final reply, so their
    // transport and request packets are still owned here; free them.
    for (const auto &entry : requests_) {
        const auto &pending = *entry.second;
        delete pending.trans;
        delete pending.pck;
    }
}

void RPCManager::TickObjs()
{
    time_t curTime = GET_UNIX_TIME;
    for (; tick_time_ < curTime; ++tick_time_) {
        std::list<TickInfo> others;
        auto &objs = tick_objs_[tick_time_ % DEF_RPC_TICK_SIZE];
        do {
            std::unique_lock<std::mutex> lock(mutex_);
            for (auto itr = objs.begin(); itr != objs.end();) {
                if (itr->expiry <= tick_time_) {
                    others.splice(others.end(), objs, itr++);
                } else {
                    ++itr;
                }
            }
        } while (0);
        for (auto &info : others) {
            info.expiry = -1;  // tick obj isn't in slot.
            DoReply(nullptr, { info.sn, RPCErrorTimeout, true });
        }
    }
}

void RPCManager::OnRPCReply(INetPacket *pck)
{
    // Decode the reply header, then route it to the waiting request.
    const RPCActor::ReplyMetaInfo meta = RPCActor::ReadReplyMetaInfo(*pck);
    // DoReply() takes ownership of the packet only when it returns 0.
    const bool consumed = DoReply(pck, meta) == 0;
    if (!consumed) {
        delete pck;
    }
}

// Delivers one reply (or a synthesized timeout/interrupt when pck is null)
// to the request identified by info.sn.
// On a final reply (info.eof) the request is removed from requests_ and the
// tick wheel, and its trans/pck packets are freed. Otherwise its timeout is
// re-armed.
// Returns 0 when a callback task was created; ownership of `pck` then passes
// to that task. Returns -1 when the request is unknown or has no callback; the
// caller still owns `pck`.
int RPCManager::DoReply(INetPacket *pck, const RPCActor::ReplyMetaInfo &info)
{
    std::shared_ptr<RequestInfo> requestInfoPtr;
    do {
        std::unique_lock<std::mutex> lock(mutex_);
        auto itr = requests_.find(info.sn);
        if (itr != requests_.end()) {
            if (info.eof) {
                RemoveTickObj(*(requestInfoPtr = std::move(itr->second)));
                requests_.erase(itr);
            } else {
                auto &requestInfo = *(requestInfoPtr = itr->second);
                // expiry == -1: TickObjs()/InterruptAllRequests() has already
                // moved this tick obj into its own private list and will finish
                // the request shortly. Splicing it "from" tick_objs_[slot]
                // would touch a list it is not in and corrupt both.
                if (requestInfo.itr->expiry != -1) {
                    RelocateTickObj(requestInfo,
                        tick_objs_[requestInfo.slot], requestInfo.itr);
                }
            }
        }
    } while (0);
    if (!requestInfoPtr) {
        return -1;
    }

    const auto &requestInfo = *requestInfoPtr;
    if (info.eof) {
        delete requestInfo.trans;
        delete requestInfo.pck;
    }
    if (!requestInfo.cb) {
        return -1;
    }

    std::unique_ptr<RPCAsyncTask> task(new RPCAsyncTask(
        requestInfo.hinder, requestInfo.cb, pck, info.err, info.eof));
    if (info.err != RPCErrorNone) {
        // The first error task claims the hinder slot so that earlier queued
        // tasks of this request no longer invoke the callback.
        uintptr_t expected = 0;
        requestInfo.hinder->compare_exchange_strong(
            expected, uintptr_t(task.get()));
    }
    // Lock the owner exactly once. Calling expired() and then lock() again
    // could observe the owner being destroyed in between and dereference null.
    if (auto owner = requestInfo.owner.lock()) {
        if (owner.get() != null_owner_.lock().get()) {
            owner->AddTask(task.release());
        } else {
            TRY_BEGIN {
                task->Finish(nullptr);
            } TRY_END
            CATCH_BEGIN(const IException &e) {
                e.Print();
            } CATCH_END
            CATCH_BEGIN(...) {
            } CATCH_END
        }
    }
    // A task that was not handed to an owner is destroyed here, freeing `pck`.
    return 0;
}

void RPCManager::SendAllRequests(RPCActor *actor)
{
    // Re-push, while holding mutex_, every pending request whose tick obj is
    // tagged with this actor. A request without a transport packet sends
    // null_packet_ in its place.
    const uintptr_t actorTag = uintptr_t(actor);
    std::unique_lock<std::mutex> lock(mutex_);
    for (auto &entry : requests_) {
        auto &pending = *entry.second;
        if (pending.itr->actor != actorTag) {
            continue;
        }
        actor->PushRPCPacket(
            pending.trans != nullptr ? *pending.trans : null_packet_,
            *pending.pck, pending.args);
    }
}

// Finishes with RPCErrorInterrupt (eof) every request still in the tick wheel
// that belongs to `actor`, or every such request when `actor` is null.
// Requests already pulled out by TickObjs() are left to time out there.
void RPCManager::InterruptAllRequests(RPCActor *actor)
{
    const uintptr_t actorTag = uintptr_t(actor);
    std::list<TickInfo> interrupted;
    do {
        std::unique_lock<std::mutex> lock(mutex_);
        for (auto &objs : tick_objs_) {
            for (auto itr = objs.begin(); itr != objs.end();) {
                if (actor == nullptr || itr->actor == actorTag) {
                    // Mark "not in slot" before releasing the lock, so a
                    // concurrent DoReply() never erases or splices this node
                    // through tick_objs_ after it has moved to `interrupted`.
                    itr->expiry = -1;
                    interrupted.splice(interrupted.end(), objs, itr++);
                } else {
                    ++itr;
                }
            }
        }
    } while (0);
    for (const auto &info : interrupted) {
        DoReply(nullptr, { info.sn, RPCErrorInterrupt, true });
    }
}

void RPCManager::InterruptRequest(uint64 requestSN)
{
    // Finish one request with RPCErrorInterrupt and eof set; an unknown
    // serial number is ignored (DoReply's result is discarded).
    const RPCActor::ReplyMetaInfo interrupt{ requestSN, RPCErrorInterrupt, true };
    DoReply(nullptr, interrupt);
}

void RPCManager::PutRequestInfo(RPCActor *actor,
    uint64 requestSN, std::shared_ptr<RequestInfo> &&requestInfoPtr)
{
    // Registers a freshly sent request: gives it a new (cleared) hinder word,
    // schedules its timeout in the tick wheel and stores it under requestSN.
    RequestInfo &request = *requestInfoPtr;
    request.hinder = std::make_shared<std::atomic<uintptr_t>>(0);

    // The tick node is built in a private list first. Its iterator stays valid
    // when RelocateTickObj() splices it into a wheel slot under the lock.
    std::list<TickInfo> staging;
    staging.push_back({ requestSN, 0, uintptr_t(actor) });
    request.itr = staging.begin();

    std::unique_lock<std::mutex> lock(mutex_);
    RelocateTickObj(request, staging, request.itr);
    requests_.emplace(requestSN, std::move(requestInfoPtr));
}

void RPCManager::RemoveTickObj(RequestInfo &info)
{
    // expiry == -1 marks a tick obj that is no longer in a wheel slot; it has
    // already been moved to another list, so there is nothing to erase here.
    const bool inWheel = info.itr->expiry != -1;
    if (!inWheel) {
        return;
    }
    tick_objs_[info.slot].erase(info.itr);
}

void RPCManager::RelocateTickObj(RequestInfo &info,
    std::list<TickInfo> &other, std::list<TickInfo>::iterator itr)
{
    // (Re)arms the timeout: expiry = now-tick + timeout. The node goes into
    // slot expiry + 1, which TickObjs() drains once tick_time_ has passed the
    // expiry. The node is spliced out of `other`, so the iterator stays valid.
    itr->expiry = info.timeout + tick_time_;
    const auto slot = (itr->expiry + 1) % DEF_RPC_TICK_SIZE;
    info.slot = slot;
    auto &bucket = tick_objs_[slot];
    bucket.splice(bucket.end(), other, itr);
}

size_t RPCManager::GetWaitReplyRequests()
{
    // Number of requests still waiting for their final reply, read under mutex_.
    size_t pending = 0;
    {
        std::unique_lock<std::mutex> guard(mutex_);
        pending = requests_.size();
    }
    return pending;
}
